// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements logic for lowering StableHLO dot product ops to Linalg dialect.
// These patterns are separated out to their own file to save on the compilation
// times, given that we instantiate a large number of class templates here.

#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
// Classifies a stablehlo.dot op by the ranks of its two operands. The dot op
// only admits rank-1/rank-2 operands, so four supported combinations exist.
enum class DotOperationType {
  kVectorDot = 0,  // 1-D lhs, 1-D rhs -> scalar result.
  kMatrixVector,   // 2-D lhs, 1-D rhs -> 1-D result.
  kVectorMatrix,   // 1-D lhs, 2-D rhs -> 1-D result.
  kMatrixMatrix,   // 2-D lhs, 2-D rhs -> 2-D result.
  kUnsupported     // Any other rank combination or mismatched contraction dim.
};

// Determines which kind of dot product `dotOp` performs from the ranks of its
// operands, additionally verifying that the contracted dimension sizes agree
// (a dynamic size is treated as compatible with anything).
DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) {
  ArrayRef<int64_t> lhs =
      cast<ShapedType>(dotOp.getLhs().getType()).getShape();
  ArrayRef<int64_t> rhs =
      cast<ShapedType>(dotOp.getRhs().getType()).getShape();
  // Two dimension sizes are compatible when equal or when either is dynamic.
  auto compatible = [](int64_t a, int64_t b) {
    return a == b || ShapedType::isDynamic(a) || ShapedType::isDynamic(b);
  };
  size_t lhsRank = lhs.size();
  size_t rhsRank = rhs.size();
  if (lhsRank == 1 && rhsRank == 1 && compatible(lhs[0], rhs[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhsRank == 2 && rhsRank == 1 && compatible(lhs[1], rhs[0])) {
    return DotOperationType::kMatrixVector;
  }
  if (lhsRank == 1 && rhsRank == 2 && compatible(lhs[0], rhs[0])) {
    return DotOperationType::kVectorMatrix;
  }
  if (lhsRank == 2 && rhsRank == 2 && compatible(lhs[1], rhs[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}

// Collects the SSA sizes of the dynamic dimensions of the result tensor for a
// dot operation of kind `type`, in result-dimension order. Result dims come
// from lhs dim 0 (rows) and/or rhs dim 1 (columns) depending on the kind.
SmallVector<Value, 2> getDotOpEmptyTensorDynSizes(OpBuilder &b, Location loc,
                                                  Value lhs, Value rhs,
                                                  DotOperationType type) {
  SmallVector<Value, 2> dynSizes;
  // Appends a tensor.dim for `dim` of `operand` iff that dimension is dynamic.
  auto appendIfDynamic = [&](Value operand, int64_t dim) {
    if (llvm::cast<ShapedType>(operand.getType()).isDynamicDim(dim)) {
      dynSizes.push_back(b.create<tensor::DimOp>(loc, operand, dim));
    }
  };
  switch (type) {
  case DotOperationType::kMatrixMatrix:
    appendIfDynamic(lhs, 0);
    appendIfDynamic(rhs, 1);
    break;
  case DotOperationType::kMatrixVector:
    appendIfDynamic(lhs, 0);
    break;
  case DotOperationType::kVectorMatrix:
    appendIfDynamic(rhs, 1);
    break;
  case DotOperationType::kVectorDot:
  case DotOperationType::kUnsupported:
    // Scalar result (or unsupported op): no dynamic result dimensions.
    break;
  }
  return dynSizes;
}

// Lowers a stablehlo.dot of kind `op_type` to the corresponding named
// structured op `LinalgOp` (matmul/matvec/vecmat/dot): creates a zero-filled
// init tensor (dense or sparse, matching the result encoding) and emits the
// named op accumulating into it.
template <DotOperationType op_type, typename LinalgOp>
struct DotOpConversion final : OpConversionPattern<mlir::stablehlo::DotOp> {
  using OpConversionPattern<mlir::stablehlo::DotOp>::OpConversionPattern;
  using OpAdaptor = mlir::stablehlo::DotOp::Adaptor;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
      return failure();
    }
    // Each template instantiation handles exactly one rank combination.
    if (getDotOperationType(op) != op_type) {
      return failure();
    }

    Location loc = op.getLoc();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto resultType =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    SmallVector<Value, 2> dynSizes = getDotOpEmptyTensorDynSizes(
        rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type);
    Value init;
    if (sparse_tensor::getSparseTensorEncoding(resultType)) {
      init = getEmptySparseTensor(rewriter, loc, resultType, dynSizes);
    } else {
      init = getEmptyTensor(rewriter, loc, resultType, dynSizes);
    }
    Value filled = fillTensorWithZeros(rewriter, loc, init);
    rewriter.replaceOpWithNewOp<LinalgOp>(
        op, TypeRange{resultType},
        ValueRange{adaptor.getLhs(), adaptor.getRhs()}, ValueRange{filled},
        linalg::getPrunedAttributeList(op));
    return success();
  }
};

// Lowers stablehlo.dot_general ops that are already in canonical batch-matmul
// form (single batch dim 0 on both operands, lhs contracting dim 2, rhs
// contracting dim 1) directly to linalg.batch_matmul.
struct DotGeneralBatchMatMulOpConversion final
    : OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
      return failure();
    }
    if (llvm::cast<RankedTensorType>(op.getType()).getRank() != 3) {
      return rewriter.notifyMatchFailure(op, "expected a batch matmul");
    }

    mlir::stablehlo::DotDimensionNumbersAttr dims =
        op.getDotDimensionNumbers();
    // True iff `dimList` contains exactly the single dimension `expected`.
    auto isSingleDim = [](ArrayRef<int64_t> dimList, int64_t expected) {
      return dimList.size() == 1 && dimList.front() == expected;
    };
    if (!isSingleDim(dims.getLhsBatchingDimensions(), 0)) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs batching dimensions exactly {0}");
    }
    if (!isSingleDim(dims.getRhsBatchingDimensions(), 0)) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs batching dimensions exactly {0}");
    }
    if (!isSingleDim(dims.getLhsContractingDimensions(), 2)) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs contracting dimensions exactly {2}");
    }
    if (!isSingleDim(dims.getRhsContractingDimensions(), 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs contracting dimensions exactly {1}");
    }

    Location loc = op.getLoc();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto resultType =
        cast<ShapedType>(typeConverter->convertType(op.getType()));
    Value init =
        getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    Value filled = fillTensorWithZeros(rewriter, loc, init);
    auto batchMatmul = rewriter.create<linalg::BatchMatmulOp>(
        loc, /*resultTensorTypes=*/TypeRange{resultType},
        /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        /*outputBuffers=*/ValueRange{filled},
        linalg::getPrunedAttributeList(op));

    rewriter.replaceOp(op, batchMatmul->getResults());
    return success();
  }
};

// Fallback lowering for arbitrary stablehlo.dot_general ops to a
// linalg.generic with a matmul-style region. Loop dimensions are ordered as:
// the result dims (batch dims first, then lhs free dims, then rhs free dims)
// followed by the contraction (reduction) dims.
struct DotGeneralOpConversion final
    : OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
      return failure();
    }

    // Gather the dimension numbers that describe the contraction.
    mlir::stablehlo::DotDimensionNumbersAttr dims =
        op.getDotDimensionNumbers();
    ArrayRef<int64_t> lhsBatch = dims.getLhsBatchingDimensions();
    ArrayRef<int64_t> rhsBatch = dims.getRhsBatchingDimensions();
    ArrayRef<int64_t> lhsContract = dims.getLhsContractingDimensions();
    ArrayRef<int64_t> rhsContract = dims.getRhsContractingDimensions();

    assert(lhsContract.size() == rhsContract.size() &&
           "number of contracting dims must be equal");
    size_t numContracting = lhsContract.size();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto resultType =
        cast<ShapedType>(typeConverter->convertType(op.getType()));
    size_t resultRank = resultType.getRank();
    // One loop per result dimension (parallel) plus one per contraction
    // dimension (reduction); the reductions trail the parallel loops.
    size_t numLoops = numContracting + resultRank;

    int64_t lhsRank =
        llvm::cast<ShapedType>(adaptor.getLhs().getType()).getRank();
    size_t lhsFreeDims =
        lhsRank - lhsBatch.size() - lhsContract.size();
    int64_t rhsRank =
        llvm::cast<ShapedType>(adaptor.getRhs().getType()).getRank();

    Location loc = op.getLoc();
    Value init =
        getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    Value filled = fillTensorWithZeros(rewriter, loc, init);

    // Builds the indexing map for one operand: batching dims map to the
    // leading loop dims, contracting dims map to the loops past the result
    // dims, and the remaining (free) dims are numbered consecutively starting
    // at `freeDimStart`.
    auto makeOperandMap = [&](int64_t rank, ArrayRef<int64_t> batchingDims,
                              ArrayRef<int64_t> contractingDims,
                              size_t freeDimStart) {
      SmallVector<AffineExpr> exprs(rank);
      for (const auto &en : llvm::enumerate(batchingDims)) {
        exprs[en.value()] = rewriter.getAffineDimExpr(en.index());
      }
      for (const auto &en : llvm::enumerate(contractingDims)) {
        exprs[en.value()] = rewriter.getAffineDimExpr(en.index() + resultRank);
      }
      // Any operand dim not assigned above is a free dim of the result.
      for (int64_t i = 0; i < rank; ++i) {
        if (!exprs[i]) {
          exprs[i] = rewriter.getAffineDimExpr(freeDimStart++);
        }
      }
      return AffineMap::get(/*dimCount=*/numLoops, /*symbolCount=*/0, exprs,
                            op->getContext());
    };

    SmallVector<AffineMap, 3> indexingMaps;
    indexingMaps.push_back(
        makeOperandMap(lhsRank, lhsBatch, lhsContract, lhsBatch.size()));
    indexingMaps.push_back(makeOperandMap(rhsRank, rhsBatch, rhsContract,
                                          rhsBatch.size() + lhsFreeDims));

    // The output map is the identity over the leading `resultRank` loops.
    SmallVector<AffineExpr> outExprs;
    outExprs.reserve(resultRank);
    for (unsigned i = 0; i < resultRank; ++i) {
      outExprs.push_back(rewriter.getAffineDimExpr(i));
    }
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/numLoops,
                                          /*symbolCount=*/0, outExprs,
                                          op.getContext()));

    Operation *genericOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/TypeRange{resultType},
        /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        /*outputBuffers=*/ValueRange{filled}, indexingMaps,
        getParallelAndReductionIterators(
            /*nLoops=*/numLoops,
            /*nReduction=*/numContracting),
        [](OpBuilder &b, Location loc, ValueRange) {
          // Reuse linalg.matmul's region builder: multiply-accumulate.
          ImplicitLocOpBuilder builder(loc, b);
          linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {});
        },
        linalg::getPrunedAttributeList(op));

    rewriter.replaceOp(op, genericOp->getResults());
    return success();
  }
};

} // namespace

namespace detail {
// Registers all dot-product lowering patterns. The named-op specializations
// carry a higher benefit so they are attempted before the generic fallback.
void populateStableHloDotProdToLinalgConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns) {
  // Ensure specialized patterns are higher priority than their generic
  // versions.
  const PatternBenefit specializedBenefit(2);
  const PatternBenefit genericBenefit(1);
  patterns->add<
      DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>>(
      typeConverter, context, specializedBenefit);
  patterns->add<
      DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>>(
      typeConverter, context, specializedBenefit);
  patterns->add<
      DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>>(
      typeConverter, context, specializedBenefit);
  patterns->add<DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>>(
      typeConverter, context, specializedBenefit);
  patterns->add<DotGeneralBatchMatMulOpConversion>(typeConverter, context,
                                                   specializedBenefit);
  patterns->add<DotGeneralOpConversion>(typeConverter, context,
                                        genericBenefit);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo
